{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graphein Protein Structure Dataloaders\n",
    "## PyTorch Geometric Datasets\n",
    "\n",
    "[API Reference](https://graphein.ai/modules/graphein.ml.html#graphein.ml.datasets.torch_geometric_dataset)\n",
    "\n",
    "Graphein provides three dataset classes for working with [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/):\n",
    "\n",
    "* [`ProteinGraphDataset`]() - For processing large datasets that can't be kept in memory\n",
    "* [`InMemoryProteinGraphDataset`]() - For smaller datasets that can be kept in memory\n",
    "* [`ProteinGraphListDataset`]() - For creating a dataset from a list of pre-computed PyTorch Geometric graphs.\n",
    "\n",
    "Both `ProteinGraphDataset` and `InMemoryGraphDataset` will take care of downloading structures from either the [RCSB PDB](https://www.rcsb.org/), [EBI AlphaFold database](https://alphafold.com/), or both!\n",
    "`ProteinGraphListDataset` is a lightweight alternative for creating a dataset from a collection of graphs you have pre-computed.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/a-r-j/graphein/blob/master/notebooks/dataloader_tutorial.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install graphein if necessary\n",
    "# !pip install graphein\n",
    "\n",
    "# Install torch if necessary. See https://pytorch.org/get-started/locally/\n",
    "# pip install torch==1.11.0\n",
    "\n",
    "# Install torch geometric if necessary. See: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html\n",
    "# pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cpu.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ProteinGraphDataset\n",
    "\n",
    "[API Reference](https://graphein.ai/modules/graphein.ml.html#graphein.ml.datasets.torch_geometric_dataset.ProteinGraphDataset)\n",
    "\n",
    "`ProteinGraphDataset` will download structures from the PDB/AlphafoldDB, process the structures into graphs according to a `ProteinGraphConfig`.\n",
    "\n",
    "#### Parameters\n",
    "```python\n",
    "ProteinGraphDataset(\n",
    "        root: str,                                                             \n",
    "        # Root directory where the dataset should be saved.\n",
    "        name: str,                                                             \n",
    "        # Name of the dataset. Will be saved to ``data_$name.pt``.\n",
    "        pdb_paths:Optional[List[str]] =None,\n",
    "        # List of full path of pdb files to load.\n",
    "        pdb_codes: Optional[List[str]] = None,                                 \n",
    "        #  List of PDB codes to download and parse from the PDB.\n",
    "        uniprot_ids: Optional[List[str]] = None,                               \n",
    "        # List of Uniprot IDs to download and parse from Alphafold Database\n",
    "        graph_label_map: Optional[Dict[str, torch.Tensor]] = None,             \n",
    "        # Dictionary mapping PDB/Uniprot IDs to graph-level labels.\n",
    "        node_label_map: Optional[Dict[str, torch.Tensor]] = None,              \n",
    "        # Dictionary mapping PDB/Uniprot IDs to node-level labels.\n",
    "        chain_selection_map: Optional[Dict[str, List[str]]] = None,            \n",
    "        # Dictionary mapping PDB/Uniprot IDs to the desired chains in the PDB files\n",
    "        graphein_config: ProteinGraphConfig = ProteinGraphConfig(),            \n",
    "        # Protein graph construction config\n",
    "        graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor(   \n",
    "            src_format=\"nx\", dst_format=\"pyg\"\n",
    "        ),\n",
    "        # Conversion handler for graphs\n",
    "        graph_transformation_funcs: Optional[List[Callable]] = None,           \n",
    "        # List of functions that consume a nx.Graph and return a nx.Graph. Applied to graphs after construction but before conversion to pyg\n",
    "        transform: Optional[Callable] = None,                                  \n",
    "        # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access.\n",
    "        pdb_transform: Optional[List[Callable]] = None,\n",
    "        pre_transform: Optional[Callable] = None,                              \n",
    "        # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk\n",
    "        pre_filter: Optional[Callable] = None,                                 \n",
    "        # A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset\n",
    "        num_cores: int = 16,                                                   \n",
    "        # Number of cores to use for multiprocessing of graph construction\n",
    "        af_version: int = 2,                                                   \n",
    "        #  Version of AlphaFoldDB structures to use,\n",
    "    )\n",
    "```\n",
    "\n",
    "\n",
    "#### Directory Structure\n",
    "Creating a ``ProteinGraphDataset`` will create two directories under ``root``:\n",
    "\n",
    "* ``root/raw`` - Contains raw PDB files which are downloaded\n",
    "* ``root/processed`` - Contains processed graphs (in ``pytorch_geometric.data.Data`` format) saved as ``$PDB.pt / $UNIPROT_ID.pt``"
   ]
  },
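  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the layout above concrete, here is a minimal sketch that lists the processed files one would expect for a set of identifiers. The helper `expected_processed_files` is hypothetical (it is not part of Graphein), and it assumes one `<ID>.pt` file per structure under `root/processed`, as described above.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "from typing import List, Optional\n",
    "\n",
    "\n",
    "def expected_processed_files(\n",
    "    root: str,\n",
    "    pdb_codes: Optional[List[str]] = None,\n",
    "    uniprot_ids: Optional[List[str]] = None,\n",
    ") -> List[Path]:\n",
    "    # Hypothetical helper: assumes one <ID>.pt file per structure in root/processed.\n",
    "    ids = list(pdb_codes or []) + list(uniprot_ids or [])\n",
    "    processed_dir = Path(root) / \"processed\"\n",
    "    return [processed_dir / f\"{identifier}.pt\" for identifier in ids]\n",
    "\n",
    "\n",
    "files = expected_processed_files(\"data\", pdb_codes=[\"3eiy\", \"4hhb\"], uniprot_ids=[\"Q5VSL9\"])\n",
    "print([f.name for f in files])\n",
    "```\n",
    "\n",
    "Comparing such a list against the actual contents of `root/processed` is one way to spot structures that failed to process."
   ]
  },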
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from graphein.ml import ProteinGraphDataset\n",
    "import graphein.protein as gp\n",
    "\n",
    "# Create some labels\n",
    "g_labels = torch.randn([5])\n",
    "n_labels = torch.randn([5, 10])\n",
    "\n",
    "g_lab_map = {\"3eiy\": g_labels[0], \"4hhb\": g_labels[1], \"Q5VSL9\": g_labels[2], \"1lds\": g_labels[3], \"Q8W3K0\": g_labels[4]}\n",
    "node_lab_map = {\"3eiy\": n_labels[0], \"4hhb\": n_labels[1], \"Q5VSL9\": n_labels[2], \"1lds\": n_labels[3], \"Q8W3K0\": n_labels[4]}\n",
    "\n",
    "# Select some chains\n",
    "chain_selection_map = {\"4hhb\": \"A\"}\n",
    "\n",
    "\n",
    "# Create the dataset\n",
    "ds = ProteinGraphDataset(\n",
    "    root = \"../graphein/ml/datasets/test\",\n",
    "    pdb_codes=[\"3eiy\", \"4hhb\", \"1lds\"],\n",
    "    uniprot_ids=[\"Q5VSL9\", \"Q8W3K0\"],\n",
    "    graph_label_map=g_lab_map,\n",
    "    node_label_map=node_lab_map,\n",
    "    chain_selection_map=chain_selection_map,\n",
    "    graphein_config=gp.ProteinGraphConfig()\n",
    ")"
   ]
  },
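  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The label maps above are keyed by PDB code or UniProt ID, so it can be worth checking that every identifier has an entry before constructing the dataset. The sketch below uses a hypothetical helper, `ids_missing_labels`, and plain floats in place of tensors. How Graphein itself handles a missing key is not covered here, so treat this purely as an optional sanity check.\n",
    "\n",
    "```python\n",
    "from typing import Dict, Iterable, List\n",
    "\n",
    "\n",
    "def ids_missing_labels(ids: Iterable[str], label_map: Dict[str, object]) -> List[str]:\n",
    "    # Hypothetical helper: return the identifiers that have no entry in the label map.\n",
    "    return [identifier for identifier in ids if identifier not in label_map]\n",
    "\n",
    "\n",
    "pdb_codes = [\"3eiy\", \"4hhb\", \"1lds\"]\n",
    "uniprot_ids = [\"Q5VSL9\", \"Q8W3K0\"]\n",
    "# Illustrative map that deliberately omits two identifiers.\n",
    "example_label_map = {\"3eiy\": 0.1, \"4hhb\": 0.2, \"Q5VSL9\": 0.3}\n",
    "\n",
    "missing = ids_missing_labels(pdb_codes + uniprot_ids, example_label_map)\n",
    "print(missing)\n",
    "```\n",
    "\n",
    "An empty result means every structure has a label. The same check applies to a node-level label map."
   ]
  },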
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 236], node_id=[2], coords=[2], name=[2], dist_mat=[2], num_nodes=238, graph_y=[2], node_y=[20], batch=[238], ptr=[3])\n",
      "Graph labels:  tensor([ 0.5660, -0.7161])\n",
      "Node labels:  tensor([-1.2430,  0.8221, -0.0296, -0.3522,  1.7685, -2.3006, -0.1209, -1.4377,\n",
      "        -1.2816, -0.7039, -0.8580, -0.5647, -1.6848, -1.5069, -2.8355, -0.4000,\n",
      "         0.3203,  0.1497, -1.0708,  0.3418])\n"
     ]
    }
   ],
   "source": [
    "# Create a dataloader from dataset and inspect a batch\n",
    "from torch_geometric.loader import DataLoader\n",
    "\n",
    "dl = DataLoader(ds, batch_size=2, shuffle=True, drop_last=True)\n",
    "for i in dl:\n",
    "    print(i)\n",
    "    print(\"Graph labels: \", i.graph_y)\n",
    "    print(\"Node labels: \", i.node_y)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load from local path\n",
    "\n",
    "\n",
    "Creating a ``ProteinGraphDataset`` from a list of full path of pdb files:\n",
    "\n",
    "* ``root/raw`` - Will be empty since no pdb files are downloaded\n",
    "* ``root/processed`` - Contains processed graphs (in ``pytorch_geometric.data.Data`` format) saved as ``$PDB.pt / $UNIPROT_ID.pt``"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['../tests/protein/test_data/1lds.pdb', '../tests/protein/test_data/4hhb.pdb', '../tests/protein/test_data/alphafold_structure.pdb']\n"
     ]
    }
   ],
   "source": [
    "# import sys\n",
    "# sys.path.append('../')  # add system path for python\n",
    "\n",
    "import os \n",
    "from graphein.protein.config import ProteinGraphConfig\n",
    "from graphein.ml import ProteinGraphDataset, ProteinGraphListDataset\n",
    "import torch \n",
    "\n",
    "local_dir = \"../tests/protein/test_data/\"\n",
    "pdb_paths = [os.path.join(local_dir, pdb_path) for pdb_path in os.listdir(local_dir) if pdb_path.endswith(\".pdb\")]\n",
    "print(pdb_paths)\n",
    "\n",
    "# let's load local dataset from local_dir!\n",
    "ds = ProteinGraphDataset(\n",
    "    root = \"../graphein/ml/datasets/test\",\n",
    "    pdb_paths = pdb_paths,\n",
    "    graphein_config=ProteinGraphConfig(),\n",
    ")"
   ]
  },
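  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The path collection above relies on `os.listdir`, which returns entries in arbitrary order. As an alternative sketch, the hypothetical helper `collect_pdb_paths` below uses `pathlib` and sorts the result, so the list of paths is the same on every run. It is not part of Graphein.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "from typing import List\n",
    "\n",
    "\n",
    "def collect_pdb_paths(directory: str) -> List[str]:\n",
    "    # Hypothetical helper: gather *.pdb files and sort them for a reproducible order.\n",
    "    return sorted(str(path) for path in Path(directory).glob(\"*.pdb\"))\n",
    "```\n",
    "\n",
    "Calling `collect_pdb_paths(local_dir)` should yield a list that can be passed as `pdb_paths` in the same way as the list built above."
   ]
  },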
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 666], node_id=[2], coords=[2], name=[2], dist_mat=[2], num_nodes=671, batch=[671], ptr=[3])\n"
     ]
    }
   ],
   "source": [
    "# Create a dataloader from dataset and inspect a batch\n",
    "from torch_geometric.loader import DataLoader\n",
    "dl = DataLoader(ds, batch_size=2, shuffle=True, drop_last=True)\n",
    "for i in dl:\n",
    "    print(i)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### InMemoryProteinGraphDataset\n",
    "\n",
    "[API Reference](https://graphein.ai/modules/graphein.ml.html#graphein.ml.datasets.torch_geometric_dataset.InMemoryProteinGraphDataset)\n",
    "\n",
    "#### Parameters\n",
    "```python\n",
    "InMemoryProteinGraphDataset(\n",
    "        root: str,                                                             \n",
    "        # Root directory where the dataset should be saved.\n",
    "        name: str,                                                             \n",
    "        # Name of the dataset. Will be saved to ``data_$name.pt``.\n",
    "        pdb_paths:Optional[List[str]] =None,\n",
    "        # List of full path of pdb files to load.\n",
    "        pdb_codes: Optional[List[str]] = None,                                 \n",
    "        #  List of PDB codes to download and parse from the PDB.\n",
    "        uniprot_ids: Optional[List[str]] = None,                               \n",
    "        # List of Uniprot IDs to download and parse from Alphafold Database\n",
    "        graph_label_map: Optional[Dict[str, torch.Tensor]] = None,             \n",
    "        # Dictionary mapping PDB/Uniprot IDs to graph-level labels.\n",
    "        node_label_map: Optional[Dict[str, torch.Tensor]] = None,              \n",
    "        # Dictionary mapping PDB/Uniprot IDs to node-level labels.\n",
    "        chain_selection_map: Optional[Dict[str, List[str]]] = None,            \n",
    "        # Dictionary mapping PDB/Uniprot IDs to the desired chains in the PDB files\n",
    "        graphein_config: ProteinGraphConfig = ProteinGraphConfig(),            \n",
    "        # Protein graph construction config\n",
    "        graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor(   \n",
    "            src_format=\"nx\", dst_format=\"pyg\"\n",
    "        ),\n",
    "        # Conversion handler for graphs\n",
    "        graph_transformation_funcs: Optional[List[Callable]] = None,           \n",
    "        # List of functions that consume a nx.Graph and return a nx.Graph. Applied to graphs after construction but before conversion to pyg\n",
    "        transform: Optional[Callable] = None,                                  \n",
    "        # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access.\n",
    "        pdb_transform: Optional[List[Callable]] = None,\n",
    "        pre_transform: Optional[Callable] = None,                              \n",
    "        # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk\n",
    "        pre_filter: Optional[Callable] = None,                                 \n",
    "        # A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset\n",
    "        num_cores: int = 16,                                                   \n",
    "        # Number of cores to use for multiprocessing of graph construction\n",
    "        af_version: int = 2,                                                   \n",
    "        #  Version of AlphaFoldDB structures to use,\n",
    "    )\n",
    "```\n",
    "\n",
    "#### Directory Structure\n",
    "Creating an ``InMemoryProteinGraphDataset`` will create two directories under ``root``:\n",
    "* ``root/raw`` - Contains raw PDB files\n",
    "* ``root/processed`` - Contains processed datasets saved as ``data_{name}.pt``"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing...\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/4hhb.pdb. Chain selection: A\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/Q5VSL9.pdb. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/1lds.pdb. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/2ll6.pdb. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/3eiy.pdb. Chain selection: all\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 97 total nodes\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.graphs:Detected 174 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 141 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 837 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 165 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "from graphein.ml import InMemoryProteinGraphDataset\n",
    "\n",
    "g_lab_map = {\"3eiy\": 1, \"4hhb\": 2, \"Q5VSL9\": 3, \"1lds\": 10, \"2ll6\": 4}\n",
    "node_lab_map = {\"3eiy\": 1, \"4hhb\": 2, \"Q5VSL9\": 3, \"1lds\": 10, \"2ll6\": 4}\n",
    "chain_selection_map = {\"4hhb\": \"A\"}\n",
    "\n",
    "ds = InMemoryProteinGraphDataset(\n",
    "    root = \"../graphein/ml/datasets/test\",\n",
    "    name=\"test\",\n",
    "    pdb_codes=[\"3eiy\", \"4hhb\", \"1lds\", \"2ll6\"],\n",
    "    uniprot_ids=[\"Q5VSL9\"],\n",
    "    graph_label_map=g_lab_map,\n",
    "    node_label_map=node_lab_map,\n",
    "    chain_selection_map=chain_selection_map\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 236], node_id=[2], coords=[2], name=[2], dist_mat=[2], graph_y=[2], node_y=[2], num_nodes=238, batch=[238], ptr=[3])\n"
     ]
    }
   ],
   "source": [
    "# Create a dataloader from dataset and inspect a batch\n",
    "dl = DataLoader(ds, batch_size=2, shuffle=True, drop_last=True)\n",
    "for i in dl:\n",
    "    print(i)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load from local path\n",
    "\n",
    "\n",
    "Creating an ``InMemoryProteinGraphDataset`` from a list of full path of pdb files:\n",
    "\n",
    "* ``root/raw`` - Will be empty since no pdb files are downloaded\n",
    "* ``root/processed`` - Contains processed datasets saved as ``data_{name}.pt``\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['../tests/protein/test_data/1lds.pdb', '../tests/protein/test_data/4hhb.pdb', '../tests/protein/test_data/alphafold_structure.pdb']\n",
      "Constructing Graphs...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing...\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.2526402473449707,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "",
       "rate": null,
       "total": 3,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d5ed353098664f6f803fa502264df986",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converting Graphs...\n",
      "Saving Data...\n",
      "Done!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Done!\n"
     ]
    }
   ],
   "source": [
    "from graphein.ml.datasets.torch_geometric_dataset import InMemoryProteinGraphDataset\n",
    "\n",
    "\n",
    "local_dir = \"../tests/protein/test_data/\"\n",
    "pdb_paths = [os.path.join(local_dir, pdb_path) for pdb_path in os.listdir(local_dir) if pdb_path.endswith(\".pdb\")]\n",
    "print(pdb_paths)\n",
    "\n",
    "# let's load local dataset from local_dir!\n",
    "ds = InMemoryProteinGraphDataset(\n",
    "    root = \"../graphein/ml/datasets/test\",\n",
    "    name = \"test\",\n",
    "    pdb_paths = pdb_paths,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 951], node_id=[2], coords=[2], name=[2], dist_mat=[2], num_nodes=956, batch=[956], ptr=[3])\n"
     ]
    }
   ],
   "source": [
    "# Create a dataloader from dataset and inspect a batch\n",
    "dl = DataLoader(ds, batch_size=2, shuffle=True, drop_last=True)\n",
    "for i in dl:\n",
    "    print(i)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ProteinGraphListDataset\n",
    "\n",
    "[API Reference](https://graphein.ai/modules/graphein.ml.html#graphein.ml.datasets.torch_geometric_dataset.ProteinGraphListDataset)\n",
    "\n",
    "The `ProteinGraphListDataset` class is a lightweight class for wrapping a list of pre-computed `pytorch_geometric.data.Data` graphs.\n",
    "\n",
    "#### Parameters\n",
    "\n",
    "```python\n",
    "ProteinGraphListDataset(\n",
    "    root: str,                              # Root directory where the dataset is stored.\n",
    "    data_list: List[Data],                  # List of protein graphs as PyTorch Geometric Data objects.\n",
    "    name: str,                              # Name of dataset. Data will be saved as ``data_{name}.pt``.\n",
    "    transform: Optional[Callable]=None      # A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access.\n",
    "    )\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "INFO:graphein.protein.graphs:Constructing graph for: 4hhb. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: 3eiy. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: 1lds. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: 2ll6. Chain selection: all\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 174 total nodes\n",
      "DEBUG:graphein.protein.graphs:Detected 97 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "174\n",
      "97\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 574 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "574\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 165 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "165\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG:graphein.protein.subgraphs:Found 174 nodes in the chain subgraph.\n",
      "DEBUG:graphein.protein.subgraphs:Creating subgraph from nodes: ['A:VAL:107', 'A:SER:26', 'A:PHE:45', 'A:PRO:53', 'A:VAL:29', 'A:ALA:144', 'A:LEU:39', 'A:GLN:14', 'A:TYR:56', 'A:PRO:69', 'A:MET:117', 'A:ASP:98', 'A:SER:123', 'A:VAL:72', 'A:PRO:60', 'A:ILE:135', 'A:THR:48', 'A:ALA:33', 'A:PHE:78', 'A:ALA:108', 'A:ASN:55', 'A:ARG:87', 'A:GLU:21', 'A:THR:167', 'A:GLU:27', 'A:MET:96', 'A:ALA:119', 'A:VAL:74', 'A:TYR:31', 'A:LEU:63', 'A:VAL:114', 'A:ALA:36', 'A:ASP:168', 'A:GLY:83', 'A:PHE:173', 'A:GLN:25', 'A:TRP:156', 'A:PHE:16', 'A:ASP:71', 'A:LEU:94', 'A:ASP:112', 'A:LEU:12', 'A:GLY:65', 'A:LEU:81', 'A:PRO:23', 'A:VAL:102', 'A:ALA:8', 'A:LYS:136', 'A:GLU:140', 'A:ILE:22', 'A:LYS:35', 'A:ALA:82', 'A:GLU:165', 'A:PHE:139', 'A:LEU:80', 'A:LYS:152', 'A:GLY:38', 'A:GLN:61', 'A:SER:2', 'A:GLY:169', 'A:LEU:106', 'A:ASP:157', 'A:ASN:172', 'A:LYS:10', 'A:VAL:18', 'A:ARG:44', 'A:GLU:99', 'A:ASP:125', 'A:ILE:19', 'A:ILE:159', 'A:GLY:49', 'A:ARG:89', 'A:ILE:59', 'A:TYR:52', 'A:ALA:90', 'A:LYS:95', 'A:VAL:170', 'A:LEU:91', 'A:ASP:68', 'A:GLU:146', 'A:LYS:143', 'A:LYS:132', 'A:PRO:28', 'A:SER:84', 'A:LEU:40', 'A:PRO:77', 'A:GLY:92', 'A:GLY:67', 'A:LYS:30', 'A:LEU:121', 'A:VAL:41', 'A:ILE:124', 'A:VAL:54', 'A:ILE:166', 'A:ALA:171', 'A:GLY:101', 'A:GLY:155', 'A:VAL:85', 'A:LYS:174', 'A:ASN:17', 'A:ILE:20', 'A:ALA:162', 'A:ASP:15', 'A:VAL:86', 'A:MET:93', 'A:LYS:175', 'A:VAL:70', 'A:HIS:163', 'A:LYS:149', 'A:LEU:73', 'A:ALA:161', 'A:LEU:37', 'A:GLY:158', 'A:SER:64', 'A:PRO:128', 'A:ARG:51', 'A:ALA:24', 'A:LYS:105', 'A:LYS:147', 'A:SER:100', 'A:GLN:134', 'A:ILE:75', 'A:ASP:103', 'A:GLY:57', 'A:LEU:145', 'A:LYS:164', 'A:LYS:122', 'A:PHE:58', 'A:GLU:154', 'A:THR:97', 'A:ASP:126', 'A:LEU:131', 'A:ALA:129', 'A:ALA:104', 'A:VAL:109', 'A:THR:118', 'A:LYS:113', 'A:VAL:127', 'A:GLY:148', 'A:TYR:130', 'A:GLY:47', 'A:VAL:153', 'A:PRO:116', 'A:GLU:32', 'A:ASP:43', 'A:CYS:115', 'A:ASN:5', 'A:ASP:34', 'A:ASP:160', 'A:THR:62', 'A:ASN:120', 'A:THR:76', 'A:HIS:111', 'A:ASP:11', 'A:VAL:6', 
'A:ASP:133', 'A:ASP:66', 'A:TYR:142', 'A:PRO:7', 'A:PRO:13', 'A:HIS:137', 'A:SER:4', 'A:PHE:3', 'A:MET:50', 'A:VAL:151', 'A:ILE:46', 'A:PRO:79', 'A:PHE:138', 'A:VAL:42', 'A:ALA:88', 'A:GLY:9', 'A:TRP:150', 'A:GLN:141', 'A:PRO:110'].\n",
      "DEBUG:graphein.protein.subgraphs:Found 141 nodes in the chain subgraph.\n",
      "DEBUG:graphein.protein.subgraphs:Creating subgraph from nodes: ['A:VAL:107', 'A:LEU:83', 'A:GLY:22', 'A:GLY:25', 'A:ALA:26', 'A:ALA:123', 'A:LEU:2', 'A:LEU:105', 'A:ALA:130', 'A:LYS:7', 'A:ALA:79', 'A:TYR:24', 'A:ASN:78', 'A:SER:131', 'A:TYR:42', 'A:LEU:100', 'A:LEU:101', 'A:SER:102', 'A:LYS:11', 'A:ALA:69', 'A:SER:35', 'A:HIS:50', 'A:HIS:58', 'A:TYR:140', 'A:HIS:20', 'A:ALA:71', 'A:LEU:136', 'A:PHE:43', 'A:PHE:46', 'A:LEU:66', 'A:VAL:121', 'A:MET:76', 'A:GLU:27', 'A:ALA:120', 'A:HIS:89', 'A:VAL:10', 'A:VAL:93', 'A:ARG:141', 'A:SER:3', 'A:SER:133', 'A:ASP:64', 'A:GLY:51', 'A:GLU:23', 'A:SER:81', 'A:GLN:54', 'A:PRO:95', 'A:THR:38', 'A:HIS:87', 'A:LYS:99', 'A:LYS:90', 'A:ASN:68', 'A:ALA:82', 'A:THR:39', 'A:LYS:139', 'A:THR:108', 'A:HIS:45', 'A:ASP:75', 'A:LEU:80', 'A:SER:124', 'A:VAL:17', 'A:LEU:86', 'A:ALA:13', 'A:LYS:127', 'A:ASP:85', 'A:THR:67', 'A:LEU:106', 'A:LYS:61', 'A:ALA:63', 'A:ASP:47', 'A:ALA:111', 'A:ALA:21', 'A:ALA:12', 'A:LEU:91', 'A:SER:138', 'A:GLU:116', 'A:LEU:48', 'A:GLU:30', 'A:SER:52', 'A:VAL:62', 'A:SER:84', 'A:PRO:77', 'A:GLY:59', 'A:PHE:98', 'A:ALA:19', 'A:ASN:9', 'A:HIS:103', 'A:ASP:6', 'A:ARG:92', 'A:LYS:60', 'A:GLY:18', 'A:PHE:36', 'A:PRO:44', 'A:PRO:4', 'A:ALA:28', 'A:LYS:40', 'A:VAL:96', 'A:THR:134', 'A:HIS:122', 'A:VAL:70', 'A:SER:49', 'A:PRO:119', 'A:THR:137', 'A:LYS:16', 'A:ASN:97', 'A:ARG:31', 'A:VAL:1', 'A:ALA:53', 'A:TRP:14', 'A:ALA:5', 'A:ALA:115', 'A:LEU:34', 'A:GLY:57', 'A:HIS:112', 'A:ALA:65', 'A:ASP:126', 'A:LEU:125', 'A:PRO:37', 'A:HIS:72', 'A:THR:118', 'A:CYS:104', 'A:ASP:94', 'A:THR:8', 'A:PHE:33', 'A:VAL:135', 'A:LYS:56', 'A:LEU:29', 'A:PRO:114', 'A:ASP:74', 'A:LEU:109', 'A:LEU:113', 'A:VAL:132', 'A:GLY:15', 'A:MET:32', 'A:VAL:55', 'A:PHE:128', 'A:ALA:88', 'A:ALA:110', 'A:PHE:117', 'A:VAL:73', 'A:THR:41', 'A:LEU:129'].\n",
      "DEBUG:graphein.protein.subgraphs:Found 97 nodes in the chain subgraph.\n",
      "DEBUG:graphein.protein.subgraphs:Creating subgraph from nodes: ['A:GLU:50', 'A:LEU:64', 'A:PHE:30', 'A:THR:4', 'A:ARG:81', 'A:LEU:39', 'A:ASN:83', 'A:ALA:15', 'A:VAL:9', 'A:ALA:79', 'A:PHE:22', 'A:ASP:96', 'A:ARG:45', 'A:TYR:78', 'A:SER:55', 'A:THR:86', 'A:ARG:12', 'A:LYS:48', 'A:SER:28', 'A:GLY:29', 'A:VAL:93', 'A:PRO:32', 'A:GLU:16', 'A:ILE:35', 'A:SER:57', 'A:SER:33', 'A:ASP:53', 'A:ILE:7', 'A:LYS:19', 'A:LEU:65', 'A:GLU:77', 'A:TRP:60', 'A:SER:20', 'A:TYR:63', 'A:PHE:70', 'A:TRP:95', 'A:HIS:13', 'A:ASP:38', 'A:LEU:54', 'A:PRO:5', 'A:TYR:10', 'A:LEU:87', 'A:THR:68', 'A:HIS:31', 'A:TYR:66', 'A:SER:11', 'A:CYS:25', 'A:SER:52', 'A:VAL:82', 'A:LYS:91', 'A:LEU:23', 'A:LEU:40', 'A:ARG:3', 'A:VAL:27', 'A:GLU:69', 'A:ASP:76', 'A:LYS:75', 'A:THR:73', 'A:PRO:90', 'A:GLY:43', 'A:PHE:56', 'A:VAL:85', 'A:GLY:18', 'A:GLU:44', 'A:ASN:17', 'A:VAL:37', 'A:SER:88', 'A:ASP:59', 'A:GLN:8', 'A:ASN:24', 'A:HIS:84', 'A:CYS:80', 'A:ILE:1', 'A:GLU:36', 'A:HIS:51', 'A:LYS:41', 'A:LYS:6', 'A:VAL:49', 'A:GLU:47', 'A:TYR:67', 'A:PHE:62', 'A:PRO:72', 'A:ASP:34', 'A:GLN:2', 'A:ASN:21', 'A:ILE:92', 'A:SER:61', 'A:MET:0', 'A:PRO:14', 'A:GLN:89', 'A:ILE:46', 'A:GLU:74', 'A:THR:71', 'A:LYS:94', 'A:LYS:58', 'A:ASN:42', 'A:TYR:26'].\n",
      "DEBUG:graphein.protein.subgraphs:Found 148 nodes in the chain subgraph.\n",
      "DEBUG:graphein.protein.subgraphs:Creating subgraph from nodes: ['A:THR:117', 'A:LEU:18', 'A:ALA:1', 'A:GLY:25', 'A:PHE:68', 'A:ASN:111', 'A:LEU:39', 'A:LEU:105', 'A:ALA:15', 'A:THR:26', 'A:ASP:129', 'A:VAL:91', 'A:ALA:147', 'A:GLU:82', 'A:PRO:43', 'A:ASN:53', 'A:MET:124', 'A:ALA:10', 'A:ASP:24', 'A:LYS:13', 'A:ALA:46', 'A:SER:38', 'A:GLN:143', 'A:THR:110', 'A:VAL:121', 'A:MET:76', 'A:ARG:90', 'A:ASP:2', 'A:LYS:115', 'A:GLU:45', 'A:ARG:74', 'A:PHE:16', 'A:GLY:33', 'A:THR:5', 'A:ASP:64', 'A:TYR:99', 'A:PHE:92', 'A:GLU:123', 'A:SER:81', 'A:GLN:49', 'A:GLU:140', 'A:MET:51', 'A:MET:36', 'A:ALA:103', 'A:SER:17', 'A:VAL:136', 'A:THR:29', 'A:ASP:56', 'A:ILE:27', 'A:ARG:86', 'A:ASP:122', 'A:PHE:12', 'A:THR:70', 'A:GLU:120', 'A:LEU:48', 'A:ILE:9', 'A:ALA:128', 'A:LEU:32', 'A:LYS:77', 'A:GLU:139', 'A:GLY:59', 'A:GLU:31', 'A:GLU:127', 'A:GLU:7', 'A:GLY:40', 'A:LYS:75', 'A:LYS:30', 'A:GLU:6', 'A:MET:145', 'A:ALA:57', 'A:ARG:106', 'A:ALA:73', 'A:PHE:141', 'A:GLY:98', 'A:GLY:113', 'A:THR:79', 'A:GLN:135', 'A:ASN:137', 'A:GLU:104', 'A:TYR:138', 'A:ARG:126', 'A:ASN:60', 'A:ASP:78', 'A:GLN:8', 'A:ASN:97', 'A:MET:144', 'A:GLU:114', 'A:GLU:84', 'A:VAL:35', 'A:ASP:131', 'A:PRO:66', 'A:ASP:58', 'A:MET:71', 'A:GLN:41', 'A:ASP:80', 'A:LEU:69', 'A:VAL:142', 'A:GLY:134', 'A:SER:101', 'A:ASP:20', 'A:GLN:3', 'A:THR:34', 'A:LYS:148', 'A:LEU:112', 'A:ASP:22', 'A:PHE:65', 'A:GLU:67', 'A:THR:146', 'A:MET:109', 'A:ILE:130', 'A:THR:44', 'A:MET:72', 'A:GLY:23', 'A:VAL:108', 'A:ILE:63', 'A:LYS:21', 'A:GLU:119', 'A:GLU:47', 'A:ILE:100', 'A:ARG:37', 'A:ASP:118', 'A:PHE:19', 'A:GLY:61', 'A:THR:62', 'A:GLY:132', 'A:LEU:4', 'A:GLU:87', 'A:ALA:102', 'A:ASP:133', 'A:HIS:107', 'A:ASP:95', 'A:GLY:96', 'A:GLU:54', 'A:GLU:83', 'A:ASP:50', 'A:THR:28', 'A:VAL:55', 'A:PHE:89', 'A:ASP:93', 'A:ILE:85', 'A:GLU:14', 'A:ALA:88', 'A:LYS:94', 'A:LEU:116', 'A:ILE:52', 'A:GLU:11', 'A:ASN:42', 'A:ILE:125'].\n",
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "from graphein.ml import ProteinGraphListDataset, GraphFormatConvertor\n",
    "import graphein.protein as gp\n",
    "\n",
    "# Construct graphs\n",
    "graphs = gp.construct_graphs_mp(\n",
    "    pdb_code_it=[\"3eiy\", \"4hhb\", \"1lds\", \"2ll6\"],\n",
    "    return_dict=False\n",
    "    )\n",
    "\n",
    "# do some transformation\n",
    "graphs = [gp.extract_subgraph_from_chains(g, [\"A\"]) for g in graphs]\n",
    "\n",
    "# Convert to PyG Data format\n",
    "convertor = GraphFormatConvertor(src_format=\"nx\", dst_format=\"pyg\")\n",
    "graphs = [convertor(g) for g in graphs]\n",
    "\n",
    "# Create dataset\n",
    "ds = ProteinGraphListDataset(root=\".\", data_list=graphs, name=\"list_test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(edge_index=[2, 173], node_id=[174], coords=[1], name=[1], dist_mat=[1], num_nodes=174)\n",
      "Data(edge_index=[2, 140], node_id=[141], coords=[1], name=[1], dist_mat=[1], num_nodes=141)\n",
      "Data(edge_index=[2, 96], node_id=[97], coords=[1], name=[1], dist_mat=[1], num_nodes=97)\n",
      "Data(edge_index=[2, 147], node_id=[148], coords=[1], name=[1], dist_mat=[1], num_nodes=148)\n"
     ]
    }
   ],
   "source": [
    "for i in ds:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 303], node_id=[2], coords=[2], name=[2], dist_mat=[2], graph_y=[2], node_y=[2], num_nodes=306, batch=[306], ptr=[3])\n",
      "DataBatch(edge_index=[2, 1009], node_id=[2], coords=[2], name=[2], dist_mat=[2], graph_y=[2], node_y=[2], num_nodes=1011, batch=[1011], ptr=[3])\n",
      "DataBatch(edge_index=[2, 96], node_id=[1], coords=[1], name=[1], dist_mat=[1], graph_y=[1], node_y=[1], num_nodes=97, batch=[97], ptr=[2])\n"
     ]
    }
   ],
   "source": [
    "# Create a dataloader from dataset and inspect a few batches\n",
    "dl = DataLoader(ds, batch_size=2, shuffle=True, drop_last=False)\n",
    "for i in dl:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transforms\n",
    "\n",
    "We can supply various functions to `ProteinGraphDataset` and `InMemoryProteinGraphDataset` to alter the composition of the dataset.\n",
    "\n",
    "* ``pdb_transform`` (``list(callable)``, optional) - A function that receives a list of paths to the downloaded structures. This provides an entry point to apply pre-processing from bioinformatics tools of your choosing\n",
    "\n",
    "* ``graph_transformation_funcs``: (``List[Callable]``, optional) List of functions that consume a ``nx.Graph`` and return a ``nx.Graph``. Applied to graphs after construction but before conversion to ``torch_geometric.data.Data``. Defaults to ``None``.\n",
    "\n",
    "* ``transform`` (``callable``, optional) – A function/transform that takes in a ``torch_geometric.data.Data`` object and returns a transformed version. The data object will be transformed before every access. (default: ``None``)\n",
    "\n",
    "* ``pre_transform`` (``callable``, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: ``None``)\n",
    "\n",
    "* ``pre_filter`` (``callable,`` optional) – A function that takes in a ``torch_geometric.data.Data`` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: ``None``)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "import networkx as nx\n",
    "from torch_geometric.data import Data\n",
    "\n",
    "# Create dummy transforms\n",
    "def pdb_transform_fn(files: List[str]):\n",
    "    \"\"\"Transforms raw pdbs prior to computing graphs.\"\"\"\n",
    "    return\n",
    "\n",
    "def graph_transform_fn(graph: nx.Graph) -> nx.Graph:\n",
    "    \"\"\"Transforms graphein nx.Graph prior to conversion to torch_geometric.data.Data.\"\"\"\n",
    "    return graph\n",
    "\n",
    "def transform_fn(data: Data) -> Data:\n",
    "    \"\"\"Transforms torch_geometric.data.Data prior to every access.\"\"\"\n",
    "    return data\n",
    "\n",
    "def pre_transform_fn(data: Data) -> Data:\n",
    "    \"\"\"Transforms torch_geometric.data.Data prior to saving to disk.\"\"\"\n",
    "    return data\n",
    "\n",
    "def pre_filter_fn(data: Data) -> bool:\n",
    "    \"\"\"Takes in a torch_geometric.data.Data and returns True if the data should be included in the dataset.\"\"\"\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "  0%|          | 0/4 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading PDB structure '3eiy'...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:graphein.protein.utils:Downloaded PDB file for: 3eiy\n",
      " 25%|██▌       | 1/4 [00:01<00:04,  1.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading PDB structure '4hhb'...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:graphein.protein.utils:Downloaded PDB file for: 4hhb\n",
      " 50%|█████     | 2/4 [00:03<00:03,  1.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading PDB structure '1lds'...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:graphein.protein.utils:Downloaded PDB file for: 1lds\n",
      " 75%|███████▌  | 3/4 [00:04<00:01,  1.51s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading PDB structure '2ll6'...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:graphein.protein.utils:Downloaded PDB file for: 2ll6\n",
      "100%|██████████| 4/4 [00:06<00:00,  1.66s/it]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]INFO:graphein.protein.utils:Downloaded AlphaFold PDB file for: Q5VSL9\n",
      "100%|██████████| 1/1 [00:00<00:00,  8.83it/s]\n",
      "Processing...\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "WARNING:graphein.protein.visualisation:To use the Graphein submodule graphein.protein.visualisation, you need to install: pytorch3d \n",
      "To do so, use the following command: conda install -c pytorch3d pytorch3d\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/Q5VSL9.pdb. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/1lds.pdb. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/3eiy.pdb. Chain selection: all\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/4hhb.pdb. Chain selection: A\n",
      "INFO:graphein.protein.graphs:Constructing graph for: ../graphein/ml/datasets/test/raw/2ll6.pdb. Chain selection: all\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 97 total nodes\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.graphs:Detected 174 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 141 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 837 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe\n",
      "DEBUG:graphein.protein.graphs:Detected 165 total nodes\n",
      "DEBUG:graphein.protein.features.nodes.amino_acid:Reading meiler embeddings from: /Users/arianjamasb/github/graphein/graphein/protein/features/nodes/meiler_embeddings.csv\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "from graphein.ml.datasets.torch_geometric_dataset import InMemoryProteinGraphDataset\n",
    "\n",
    "g_lab_map = {\"3eiy\": 1, \"4hhb\": 2, \"Q5VSL9\": 3, \"1lds\": 10, \"2ll6\": 4}\n",
    "node_lab_map = {\"3eiy\": 1, \"4hhb\": 2, \"Q5VSL9\": 3, \"1lds\": 10, \"2ll6\": 4}\n",
    "chain_selection_map = {\"4hhb\": \"A\"}\n",
    "\n",
    "ds = InMemoryProteinGraphDataset(\n",
    "    root = \"../graphein/ml/datasets/test\",\n",
    "    name=\"test\",\n",
    "    pdb_codes=[\"3eiy\", \"4hhb\", \"1lds\", \"2ll6\"],\n",
    "    uniprot_ids=[\"Q5VSL9\"],\n",
    "    graph_label_map=g_lab_map,\n",
    "    node_label_map=node_lab_map,\n",
    "    chain_selection_map=chain_selection_map,\n",
    "    pdb_transform=[pdb_transform_fn],\n",
    "    graph_transformation_funcs=[graph_transform_fn],\n",
    "    transform=transform_fn,\n",
    "    pre_transform=pre_transform_fn,\n",
    "    pre_filter=pre_filter_fn\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
